import pandas as pd
import statsmodels.formula.api as sm

from cat_rfs.code.utils.table import Table

# Per-variant configuration consumed by print_table4. For each variant:
#   beta_var   - name of the beta column used as a regressor,
#   yield_var  - name of the yield column used to build the excess return,
#   table_file - table label, also used as the stem of the output .tex file.
SETTINGS = {
    'main': {
        'beta_var': 'beta_sheet',
        'yield_var': 'sheet_yield_usd',
        'table_file': 'table4',
    },
    'trace': {
        'beta_var': 'beta_act',
        'yield_var': 'act_yield_usd',
        'table_file': 'table4_trace',
    },
    'loss': {
        'beta_var': 'beta_loss',
        'yield_var': 'sheet_yield_usd',
        'table_file': 'table4_loss',
    },
}


# Column layout of the coefficient panel: each estimate is followed by its standard error.
_TABLE4_COLUMNS = ['alpha', 'alpha_se', 'lambda', 'lambda_se', 'lambda_el', 'lambda_el_se',
                   'lambda_w', 'lambda_w_se', 'lambda_elw', 'lambda_elw_se']

# LaTeX row labels, aligned with _TABLE4_COLUMNS (standard-error rows are unlabelled).
_TABLE4_HEADER = ['$\\lambda_0$', '', '$\\lambda_{cat}$', '', '$\\lambda_{el}$', '',
                  '$\\lambda_{w}$', '', '$\\lambda_{el\\times w}$', '']

# Maximum lag used for the HAC covariance of the time-series averages.
_TABLE4_HAC_LAGS = 4


def _hac_mean(series, lags):
    """Return the mean of *series* and its HAC standard error (regression on a constant)."""
    fit = sm.ols('y ~ 1', data=pd.DataFrame({'y': series})).fit(
        cov_type='HAC', cov_kwds={'maxlags': lags}, use_t=True)
    return float(fit.params['Intercept']), float(fit.bse['Intercept'])


def _fama_macbeth(df, formula, terms, lags=_TABLE4_HAC_LAGS):
    """Estimate one table column: per-date cross-sectional OLS, then time-series averages.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing a ``date`` column and every variable named in *formula*.
    formula : str
        Cross-sectional regression formula, fitted separately for each date.
    terms : dict
        Maps a coefficient column of the table (e.g. ``'lambda_el'``) to the name of
        the corresponding term in the fitted parameter vector (e.g. ``'expected_loss'``).
    lags : int
        ``maxlags`` passed to the HAC covariance estimator.

    Returns
    -------
    tuple
        ``(row, n_periods, mean_r2)`` where *row* is a list aligned with
        ``_TABLE4_COLUMNS`` (``''`` for coefficients not in this specification),
        *n_periods* is the number of cross-sections and *mean_r2* is the average
        cross-sectional R-squared.

    Raises
    ------
    ValueError
        If *df* contains no dated observations.
    """
    records = []
    # groupby iterates dates in sorted order; the HAC estimator below relies on the
    # per-date estimates being in chronological order (assumes ``date`` is datetime-like).
    for _, cross_section in df.groupby('date'):
        fit = sm.ols(formula=formula, data=cross_section).fit()
        # Coefficients are scaled by 100 before averaging.
        record = {'alpha': fit.params['Intercept'] * 100, 'r2': fit.rsquared}
        for column, term in terms.items():
            # Look coefficients up by term name rather than by position.
            record[column] = fit.params[term] * 100
        records.append(record)

    if not records:
        raise ValueError('no dated observations available for the cross-sectional regressions')

    estimates = pd.DataFrame(records)

    row = dict.fromkeys(_TABLE4_COLUMNS, '')
    row['alpha'], row['alpha_se'] = _hac_mean(estimates['alpha'], lags)
    for column in terms:
        row[column], row[f'{column}_se'] = _hac_mean(estimates[column], lags)

    return [row[column] for column in _TABLE4_COLUMNS], len(estimates), estimates['r2'].mean()


def print_table4(df, output='console', setting='main'):
    """Build and print the "Exposure versus characteristics" regression table.

    For each of seven specifications, excess returns are regressed cross-sectionally
    (date by date) on combinations of the beta, the expected loss, the peril weight
    and the expected-loss/peril-weight interaction. The reported coefficients are the
    time-series averages of the per-date estimates, with HAC standard errors.

    Parameters
    ----------
    df : pandas.DataFrame
        Input panel. Columns read here: ``perils``, ``date``, ``size``,
        ``expected_loss``, ``rf``, one ``el_<peril>`` column per single-peril label,
        and the beta and yield columns selected by *setting*.
        NOTE: as in earlier versions, helper columns (``totel``, the per-peril
        ``*_weight`` / ``*_amt`` / ``*_amt_group`` / ``*_weighted`` columns and
        ``weight_peril``; plus ``er`` for the non-trace settings) are added to the
        passed frame in place.
    output : str
        ``'paper'`` writes ``cat_rfs/output/tables/<table_file>.tex``; any other value
        prints the table via ``Table.print_table()`` without a path.
    setting : str
        Key of ``SETTINGS`` selecting the beta column, yield column and table name.

    Raises
    ------
    KeyError
        If *setting* is not a key of ``SETTINGS``.
    """
    config = SETTINGS[setting]
    beta = config['beta_var']
    yield_ = config['yield_var']
    tname = config['table_file']

    # %% Calculate peril weights
    # Peril names are taken from the rows whose ``perils`` entry lists a single peril
    # (no comma) and mapped to their ``el_<peril>`` column names.
    single_peril = ~df['perils'].str.contains(',')
    perils = ['el_' + name.replace(' ', '_').lower()
              for name in df.loc[single_peril, 'perils'].unique()]
    df['totel'] = df[perils].sum(axis='columns')

    for peril in perils:
        df[f'{peril}_weight'] = df[peril] / df['totel']
        df[f'{peril}_amt'] = df[f'{peril}_weight'] * df['size']
        df[f'{peril}_amt_group'] = df.groupby(['date'])[f'{peril}_amt'].transform('sum')
        df[f'{peril}_weighted'] = df[f'{peril}_weight'] * df[f'{peril}_amt_group']

    weighted_total = df[[f'{peril}_weighted' for peril in perils]].sum(axis='columns')
    amount_total = df[[f'{peril}_amt_group' for peril in perils]].sum(axis='columns')
    df['weight_peril'] = weighted_total / amount_total

    if yield_ == 'act_yield_usd':
        # Restrict the sample window for this yield variable; copy so that the
        # assignments below do not write into a slice of the caller's frame.
        df = df[(df['date'] >= '12/31/2004') & (df['date'] <= '12/31/2018')].copy()

    # Excess return: yield minus expected loss and risk-free rate (both divided by 100).
    df['er'] = df[yield_] - df['expected_loss'] / 100 - df['rf'] / 100
    df = df.dropna(subset=[beta, 'er']).copy()

    # One entry per table column (1)-(7): regression formula and the mapping from
    # table coefficient column to the term name in the fitted model.
    interaction = 'expected_loss:weight_peril'
    specs = [
        (f'er ~ {beta}',
         {'lambda': beta}),
        ('er ~ expected_loss',
         {'lambda_el': 'expected_loss'}),
        ('er ~ weight_peril',
         {'lambda_w': 'weight_peril'}),
        ('er ~ expected_loss + weight_peril',
         {'lambda_el': 'expected_loss', 'lambda_w': 'weight_peril'}),
        ('er ~ expected_loss : weight_peril',
         {'lambda_elw': interaction}),
        (f'er ~ {beta} + expected_loss + weight_peril',
         {'lambda': beta, 'lambda_el': 'expected_loss', 'lambda_w': 'weight_peril'}),
        (f'er ~ {beta} + expected_loss : weight_peril',
         {'lambda': beta, 'lambda_elw': interaction}),
    ]

    rows = [list(_TABLE4_HEADER)]
    other_stats = [['$N$', '$R^2$']]
    n_periods = 0
    for formula, terms in specs:
        row, n_periods, mean_r2 = _fama_macbeth(df, formula, terms)
        rows.append(row)
        other_stats.append([n_periods, mean_r2])

    # Build the panel in one go (DataFrame.append no longer exists in pandas >= 2.0).
    data = pd.DataFrame(rows, columns=_TABLE4_COLUMNS)

    T = Table(8, width='auto', justs=['l', 'd', 'd', 'd', 'd', 'd', 'd', 'd'],
              caption='Exposure versus characteristics',
              label=tname, footer=None)
    # Degrees of freedom for the significance stars: number of cross-sections minus one.
    T.add_panel(data.T, headers=['', '(1)', '(2)', '(3)', '(4)', '(5)', '(6)', '(7)'],
                regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': (n_periods - 1)})
    T.add_panel(other_stats)

    if output == 'paper':
        T.print_table(f'cat_rfs/output/tables/{tname}.tex')
    else:
        T.print_table()
